{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"11_Convolutional_Neural_Networks","provenance":[{"file_id":"https://github.com/GokuMohandas/MadeWithML/blob/main/notebooks/11_Convolutional_Neural_Networks.ipynb","timestamp":1608240226044},{"file_id":"https://github.com/GokuMohandas/MadeWithML/blob/main/notebooks/11_Convolutional_Neural_Networks.ipynb","timestamp":1584545223147},{"file_id":"https://github.com/GokuMohandas/MadeWithML/blob/main/notebooks/11_Convolutional_Neural_Networks.ipynb","timestamp":1583211224146}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"kkAvSg_-4dPz"},"source":["<div align=\"center\">\n","<h1><img width=\"30\" src=\"https://madewithml.com/static/images/rounded_logo.png\">&nbsp;<a href=\"https://madewithml.com/\">Made With ML</a></h1>\n","Applied ML · MLOps · Production\n","<br>\n","Join 30K+ developers in learning how to responsibly <a href=\"https://madewithml.com/about/\">deliver value</a> with ML.\n","    <br>\n","</div>\n","\n","<br>\n","\n","<div align=\"center\">\n","    <a target=\"_blank\" href=\"https://newsletter.madewithml.com\"><img src=\"https://img.shields.io/badge/Subscribe-30K-brightgreen\"></a>&nbsp;\n","    <a target=\"_blank\" href=\"https://github.com/GokuMohandas/MadeWithML\"><img src=\"https://img.shields.io/github/stars/GokuMohandas/MadeWithML.svg?style=social&label=Star\"></a>&nbsp;\n","    <a target=\"_blank\" href=\"https://www.linkedin.com/in/goku\"><img src=\"https://img.shields.io/badge/style--5eba00.svg?label=LinkedIn&logo=linkedin&style=social\"></a>&nbsp;\n","    <a target=\"_blank\" href=\"https://twitter.com/GokuMohandas\"><img src=\"https://img.shields.io/twitter/follow/GokuMohandas.svg?label=Follow&style=social\"></a>\n","    <br>\n","    🔥&nbsp; Among the <a href=\"https://github.com/topics/deep-learning\" target=\"_blank\">top ML</a> repositories on 
GitHub\n","</div>\n","\n","<br>\n","<hr>"]},{"cell_type":"markdown","metadata":{"id":"eTdCMVl9YAXw"},"source":["# Convolutional Neural Networks (CNN)\n","\n","In this lesson we will explore the basics of Convolutional Neural Networks (CNNs) applied to text for natural language processing (NLP) tasks."]},{"cell_type":"markdown","metadata":{"id":"xuabAj4PYj57"},"source":["<div align=\"left\">\n","<a target=\"_blank\" href=\"https://madewithml.com/courses/basics/convolutional-neural-networks/\"><img src=\"https://img.shields.io/badge/📖 Read-blog post-9cf\"></a>&nbsp;\n","<a href=\"https://github.com/GokuMohandas/MadeWithML/blob/main/notebooks/11_Convolutional_Neural_Networks.ipynb\" role=\"button\"><img src=\"https://img.shields.io/static/v1?label=&amp;message=View%20On%20GitHub&amp;color=586069&amp;logo=github&amp;labelColor=2f363d\"></a>&nbsp;\n","<a href=\"https://colab.research.google.com/github/GokuMohandas/MadeWithML/blob/main/notebooks/11_Convolutional_Neural_Networks.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"></a>\n","</div>"]},{"cell_type":"markdown","metadata":{"id":"Xz15aMLj7gPX"},"source":["# Overview"]},{"cell_type":"markdown","metadata":{"id":"Ym2pswjn5GNw"},"source":["At the core of CNNs are filters (aka weights, kernels, etc.) which convolve (slide) across our input to extract relevant features. 
The filters are initialized randomly and learn to act as feature extractors during training. Because the same filter weights are reused at every position of the input (parameter sharing), a filter can detect a feature wherever it appears."]},{"cell_type":"markdown","metadata":{"id":"TPGoKi5k4j5y"},"source":["<div align=\"left\">\n","<img src=\"https://raw.githubusercontent.com/GokuMohandas/MadeWithML/main/images/basics/cnn/convolution.gif\" width=\"500\">\n","</div>"]},{"cell_type":"markdown","metadata":{"id":"JqxyljU18hvt"},"source":["* **Objective:**  Extract meaningful spatial substructure from encoded data.\n","* **Advantages:** \n","  * Small number of weights (shared)\n","  * Parallelizable\n","  * Detects spatial substructures (feature extractors)\n","  * [Interpretability](https://arxiv.org/abs/1312.6034) via filters\n","  * Can be used to process images, text, time-series, etc.\n","* **Disadvantages:**\n","  * Many hyperparameters (kernel size, strides, etc.) to tune.\n","* **Miscellaneous:** \n","  * Lots of deep CNN architectures are constantly being updated for SOTA performance.\n","  * Very popular feature extractor that's usually prepended onto other architectures.\n"]},{"cell_type":"markdown","metadata":{"id":"6GLD2nXvo-r4"},"source":["# Set up"]},{"cell_type":"code","metadata":{"id":"y3qKSoEe57na"},"source":["import numpy as np\n","import pandas as pd\n","import random\n","import torch\n","import torch.nn as nn"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ZMLLoU5htnxT"},"source":["SEED = 1234"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"xK1YEDDotQ3V"},"source":["def set_seeds(seed=1234):\n","    \"\"\"Set seeds for reproducibility.\"\"\"\n","    np.random.seed(seed)\n","    random.seed(seed)\n","    torch.manual_seed(seed)\n","    torch.cuda.manual_seed(seed)\n","    torch.cuda.manual_seed_all(seed) # multi-GPU"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"XU1TUOBitpaR"},"source":["# Set seeds for 
reproducibility\n","set_seeds(seed=SEED)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"86uoTigW0ylA","executionInfo":{"status":"ok","timestamp":1608329398628,"user_tz":420,"elapsed":3911,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"bd4c1448-851c-4415-c8ac-09d76183fc54"},"source":["# Set device\n","cuda = True\n","device = torch.device('cuda' if (\n","    torch.cuda.is_available() and cuda) else 'cpu')\n","torch.set_default_tensor_type('torch.FloatTensor')\n","if device.type == 'cuda':\n","    torch.set_default_tensor_type('torch.cuda.FloatTensor')\n","print (device)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["cuda\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"c69z9wpJ56nE"},"source":["## Load data"]},{"cell_type":"markdown","metadata":{"id":"2V_nEp5G58M0"},"source":["We will download the [AG News dataset](http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html), which consists of 120K text samples from 4 unique classes (`Business`, `Sci/Tech`, `Sports`, `World`)."]},{"cell_type":"code","metadata":{"id":"cdjdvnOGrsZP","colab":{"base_uri":"https://localhost:8080/","height":204},"executionInfo":{"status":"ok","timestamp":1608329399051,"user_tz":420,"elapsed":4304,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"53c06734-a2a9-4928-f080-49b00e9f150c"},"source":["# Load data\n","url = \"https://raw.githubusercontent.com/GokuMohandas/MadeWithML/main/datasets/news.csv\"\n","df = pd.read_csv(url, header=0) # load\n","df = df.sample(frac=1).reset_index(drop=True) # 
shuffle\n","df.head()"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["<div>\n","<style scoped>\n","    .dataframe tbody tr th:only-of-type {\n","        vertical-align: middle;\n","    }\n","\n","    .dataframe tbody tr th {\n","        vertical-align: top;\n","    }\n","\n","    .dataframe thead th {\n","        text-align: right;\n","    }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n","  <thead>\n","    <tr style=\"text-align: right;\">\n","      <th></th>\n","      <th>title</th>\n","      <th>category</th>\n","    </tr>\n","  </thead>\n","  <tbody>\n","    <tr>\n","      <th>0</th>\n","      <td>Sharon Accepts Plan to Reduce Gaza Army Operat...</td>\n","      <td>World</td>\n","    </tr>\n","    <tr>\n","      <th>1</th>\n","      <td>Internet Key Battleground in Wildlife Crime Fight</td>\n","      <td>Sci/Tech</td>\n","    </tr>\n","    <tr>\n","      <th>2</th>\n","      <td>July Durable Good Orders Rise 1.7 Percent</td>\n","      <td>Business</td>\n","    </tr>\n","    <tr>\n","      <th>3</th>\n","      <td>Growing Signs of a Slowing on Wall Street</td>\n","      <td>Business</td>\n","    </tr>\n","    <tr>\n","      <th>4</th>\n","      <td>The New Faces of Reality TV</td>\n","      <td>World</td>\n","    </tr>\n","  </tbody>\n","</table>\n","</div>"],"text/plain":["                                               title  category\n","0  Sharon Accepts Plan to Reduce Gaza Army Operat...     
World\n","1  Internet Key Battleground in Wildlife Crime Fight  Sci/Tech\n","2          July Durable Good Orders Rise 1.7 Percent  Business\n","3          Growing Signs of a Slowing on Wall Street  Business\n","4                        The New Faces of Reality TV     World"]},"metadata":{"tags":[]},"execution_count":6}]},{"cell_type":"markdown","metadata":{"id":"RQUDEgwloxhF"},"source":["## Preprocessing"]},{"cell_type":"markdown","metadata":{"id":"2QKp1TyPpBKG"},"source":["We're going to clean up our input data first with operations such as lowercasing the text, removing stop (filler) words, applying filters with regular expressions, etc."]},{"cell_type":"code","metadata":{"id":"S-Mv_g0cowkR"},"source":["import nltk\n","from nltk.corpus import stopwords\n","from nltk.stem import PorterStemmer\n","import re"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"K0DwdEzxownP","executionInfo":{"status":"ok","timestamp":1608329400344,"user_tz":420,"elapsed":5539,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"078c0da6-c18e-416f-900f-01336eb531d1"},"source":["nltk.download('stopwords')\n","STOPWORDS = stopwords.words('english')\n","print (STOPWORDS[:5])\n","porter = PorterStemmer()"],"execution_count":null,"outputs":[{"output_type":"stream","text":["[nltk_data] Downloading package stopwords to /root/nltk_data...\n","[nltk_data]   Unzipping corpora/stopwords.zip.\n","['i', 'me', 'my', 'myself', 'we']\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"w1Yyrsp0owqk"},"source":["def preprocess(text, stopwords=STOPWORDS):\n","    \"\"\"Conditional preprocessing on our text unique to our task.\"\"\"\n","    # Lower\n","    text = text.lower()\n","\n","    # Remove stopwords\n","    pattern = re.compile(r'\\b(' + r'|'.join(stopwords) + r')\\b\\s*')\n","    text = 
pattern.sub('', text)\n","\n","    # Remove words in parentheses\n","    text = re.sub(r'\\([^)]*\\)', '', text)\n","\n","    # Spacing and filters\n","    text = re.sub(r\"([-;;.,!?<=>])\", r\" \\1 \", text)\n","    text = re.sub('[^A-Za-z0-9]+', ' ', text) # remove non alphanumeric chars\n","    text = re.sub(' +', ' ', text)  # remove multiple spaces\n","    text = text.strip()\n","\n","    return text"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":35},"id":"zvcSy1wSowtC","executionInfo":{"status":"ok","timestamp":1608329400349,"user_tz":420,"elapsed":5482,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"9d8a1d6a-bd53-4824-b0b4-7d029a83baa2"},"source":["# Sample\n","text = \"Great week for the NYSE!\"\n","preprocess(text=text)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"},"text/plain":["'great week nyse'"]},"metadata":{"tags":[]},"execution_count":10}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"czjQ3lrrowwh","executionInfo":{"status":"ok","timestamp":1608329402722,"user_tz":420,"elapsed":7810,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"6f554009-55ba-4763-fdda-f942bc72e7f6"},"source":["# Apply to dataframe\n","preprocessed_df = df.copy()\n","preprocessed_df.title = preprocessed_df.title.apply(preprocess)\n","print (f\"{df.title.values[0]}\\n\\n{preprocessed_df.title.values[0]}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Sharon Accepts Plan to Reduce Gaza Army Operation, Haaretz Says\n","\n","sharon accepts plan reduce gaza army operation haaretz 
says\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"h5ZPPgPApLL1"},"source":["> If you have preprocessing steps, like standardization, whose values are calculated from the data, you need to split the data into training and test sets before applying those operations. Otherwise, knowledge gained from the test set could accidentally leak into preprocessing/training (data leakage). However, global preprocessing steps like the function above, which don't learn anything from the data itself, can be performed before splitting the data."]},{"cell_type":"markdown","metadata":{"id":"zgStr_fDpMU4"},"source":["## Split data"]},{"cell_type":"code","metadata":{"id":"287RCymQowyV"},"source":["import collections\n","from sklearn.model_selection import train_test_split"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"BU--xNveow11"},"source":["TRAIN_SIZE = 0.7\n","VAL_SIZE = 0.15\n","TEST_SIZE = 0.15"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"OyQJ6x0vpP7c"},"source":["def train_val_test_split(X, y, train_size):\n","    \"\"\"Split dataset into data splits.\"\"\"\n","    # Split off the training set, then split the remainder evenly into val and test sets\n","    X_train, X_, y_train, y_ = train_test_split(X, y, train_size=train_size, stratify=y)\n","    X_val, X_test, y_val, y_test = train_test_split(X_, y_, train_size=0.5, stratify=y_)\n","    return X_train, X_val, X_test, y_train, y_val, y_test"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"KKvtQWxepP-C"},"source":["# Data\n","X = preprocessed_df[\"title\"].values\n","y = preprocessed_df[\"category\"].values"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"bnjaba3CpQAm","executionInfo":{"status":"ok","timestamp":1608329402732,"user_tz":420,"elapsed":7720,"user":{"displayName":"Goku 
Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"3ce08f5e-6be3-40cb-ab9c-1cbf28d37f05"},"source":["# Create data splits\n","X_train, X_val, X_test, y_train, y_val, y_test = train_val_test_split(\n","    X=X, y=y, train_size=TRAIN_SIZE)\n","print (f\"X_train: {X_train.shape}, y_train: {y_train.shape}\")\n","print (f\"X_val: {X_val.shape}, y_val: {y_val.shape}\")\n","print (f\"X_test: {X_test.shape}, y_test: {y_test.shape}\")\n","print (f\"Sample point: {X_train[0]} → {y_train[0]}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["X_train: (84000,), y_train: (84000,)\n","X_val: (18000,), y_val: (18000,)\n","X_test: (18000,), y_test: (18000,)\n","Sample point: china battles north korea nuclear talks → World\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"6JJBLk0rpVrX"},"source":["## LabelEncoder"]},{"cell_type":"markdown","metadata":{"id":"gxCbJ294pYfd"},"source":["Next we'll define a `LabelEncoder` to encode our text labels into unique indices."]},{"cell_type":"code","metadata":{"id":"_voCWJ41pQIH"},"source":["import itertools"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"p9Ya97mMp0uR"},"source":["class LabelEncoder(object):\n","    \"\"\"Label encoder for tag labels.\"\"\"\n","    def __init__(self, class_to_index=None):\n","        # Avoid a shared mutable default: each encoder gets its own dict\n","        self.class_to_index = class_to_index or {}\n","        self.index_to_class = {v: k for k, v in self.class_to_index.items()}\n","        self.classes = list(self.class_to_index.keys())\n","\n","    def __len__(self):\n","        return len(self.class_to_index)\n","\n","    def __str__(self):\n","        return f\"<LabelEncoder(num_classes={len(self)})>\"\n","\n","    def fit(self, y):\n","        classes = np.unique(y)\n","        for i, class_ in enumerate(classes):\n","            self.class_to_index[class_] = i\n","        self.index_to_class = {v: k for k, v in 
self.class_to_index.items()}\n","        self.classes = list(self.class_to_index.keys())\n","        return self\n","\n","    def encode(self, y):\n","        encoded = np.zeros((len(y)), dtype=int)\n","        for i, item in enumerate(y):\n","            encoded[i] = self.class_to_index[item]\n","        return encoded\n","\n","    def decode(self, y):\n","        classes = []\n","        for i, item in enumerate(y):\n","            classes.append(self.index_to_class[item])\n","        return classes\n","\n","    def save(self, fp):\n","        with open(fp, 'w') as fp:\n","            contents = {'class_to_index': self.class_to_index}\n","            json.dump(contents, fp, indent=4, sort_keys=False)\n","\n","    @classmethod\n","    def load(cls, fp):\n","        with open(fp, 'r') as fp:\n","            kwargs = json.load(fp=fp)\n","        return cls(**kwargs)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"WIFZYOLzp0ws","executionInfo":{"status":"ok","timestamp":1608329402988,"user_tz":420,"elapsed":7917,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"29255ac5-cec4-43a2-debf-a71e84cefabe"},"source":["# Encode\n","label_encoder = LabelEncoder()\n","label_encoder.fit(y_train)\n","NUM_CLASSES = len(label_encoder)\n","label_encoder.class_to_index"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["{'Business': 0, 'Sci/Tech': 1, 'Sports': 2, 'World': 3}"]},"metadata":{"tags":[]},"execution_count":19}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"MUMaapBCp0zq","executionInfo":{"status":"ok","timestamp":1608329402990,"user_tz":420,"elapsed":7888,"user":{"displayName":"Goku 
Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"d3aa414a-477f-448e-e0e5-18ae41c1327b"},"source":["# Convert labels to indices\n","print (f\"y_train[0]: {y_train[0]}\")\n","y_train = label_encoder.encode(y_train)\n","y_val = label_encoder.encode(y_val)\n","y_test = label_encoder.encode(y_test)\n","print (f\"y_train[0]: {y_train[0]}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["y_train[0]: World\n","y_train[0]: 3\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"U2jYqr5Yp015","executionInfo":{"status":"ok","timestamp":1608329402991,"user_tz":420,"elapsed":7866,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"535ccae3-a0ed-42f1-f795-7b71bdf47071"},"source":["# Class weights\n","counts = np.bincount(y_train)\n","class_weights = {i: 1.0/count for i, count in enumerate(counts)}\n","print (f\"counts: {counts}\\nweights: {class_weights}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["counts: [21000 21000 21000 21000]\n","weights: {0: 4.761904761904762e-05, 1: 4.761904761904762e-05, 2: 4.761904761904762e-05, 3: 4.761904761904762e-05}\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"8ULHakYnp7fA"},"source":["# Tokenizer"]},{"cell_type":"markdown","metadata":{"id":"FXEyU6evp-4P"},"source":["Our input data is text, which we can't feed directly to our models. So we'll define a `Tokenizer` to convert our text input data into token indices: every token (we decide whether a token is a character, word, sub-word, etc.) is mapped to a unique index, which allows us to represent our text as an array of indices. 
"]},{"cell_type":"code","metadata":{"id":"DxSzCrYJpQKq"},"source":["import json\n","from collections import Counter\n","from more_itertools import take"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ymD8m7Lfow4J"},"source":["class Tokenizer(object):\n","    def __init__(self, char_level, num_tokens=None, \n","                 pad_token='<PAD>', oov_token='<UNK>',\n","                 token_to_index=None):\n","        self.char_level = char_level\n","        self.separator = '' if self.char_level else ' '\n","        if num_tokens: num_tokens -= 2 # pad + unk tokens\n","        self.num_tokens = num_tokens\n","        self.pad_token = pad_token\n","        self.oov_token = oov_token\n","        if not token_to_index:\n","            token_to_index = {pad_token: 0, oov_token: 1}\n","        self.token_to_index = token_to_index\n","        self.index_to_token = {v: k for k, v in self.token_to_index.items()}\n","\n","    def __len__(self):\n","        return len(self.token_to_index)\n","\n","    def __str__(self):\n","        return f\"<Tokenizer(num_tokens={len(self)})>\"\n","\n","    def fit_on_texts(self, texts):\n","        if not self.char_level:\n","            texts = [text.split(\" \") for text in texts]\n","        all_tokens = [token for text in texts for token in text]\n","        counts = Counter(all_tokens).most_common(self.num_tokens)\n","        self.min_token_freq = counts[-1][1]\n","        for token, count in counts:\n","            index = len(self)\n","            self.token_to_index[token] = index\n","            self.index_to_token[index] = token\n","        return self\n","\n","    def texts_to_sequences(self, texts):\n","        sequences = []\n","        for text in texts:\n","            if not self.char_level:\n","                text = text.split(' ')\n","            sequence = []\n","            for token in text:\n","                sequence.append(self.token_to_index.get(\n","                    token, 
self.token_to_index[self.oov_token]))\n","            sequences.append(np.asarray(sequence))\n","        return sequences\n","\n","    def sequences_to_texts(self, sequences):\n","        texts = []\n","        for sequence in sequences:\n","            text = []\n","            for index in sequence:\n","                text.append(self.index_to_token.get(index, self.oov_token))\n","            texts.append(self.separator.join([token for token in text]))\n","        return texts\n","\n","    def save(self, fp):\n","        with open(fp, 'w') as fp:\n","            contents = {\n","                'char_level': self.char_level,\n","                'oov_token': self.oov_token,\n","                'token_to_index': self.token_to_index\n","            }\n","            json.dump(contents, fp, indent=4, sort_keys=False)\n","\n","    @classmethod\n","    def load(cls, fp):\n","        with open(fp, 'r') as fp:\n","            kwargs = json.load(fp=fp)\n","        return cls(**kwargs)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UUNjbGizqb1v"},"source":["We're going to restrict the number of tokens in our `Tokenizer` to the top 500 most frequent tokens (stop words already removed) because the full vocabulary size (~30K) is too large to run on Google Colab notebooks.\n","\n","> It's important that we fit the tokenizer using only our train data split: during inference, our model will not know every token, so we want to replicate that scenario with our validation and test splits as well."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"2HWkNc94qVH8","executionInfo":{"status":"ok","timestamp":1608329403510,"user_tz":420,"elapsed":8326,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"b9064379-200b-477d-b617-61ae91eee41a"},"source":["# Tokenize\n","tokenizer = 
Tokenizer(char_level=False, num_tokens=500)\n","tokenizer.fit_on_texts(texts=X_train)\n","VOCAB_SIZE = len(tokenizer)\n","print (tokenizer)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["<Tokenizer(num_tokens=500)>\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"9x_RK5xBqVKX","executionInfo":{"status":"ok","timestamp":1608329403512,"user_tz":420,"elapsed":8300,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"00d035cd-4f54-4860-fd38-431ef8f55099"},"source":["# Sample of tokens\n","print (take(5, tokenizer.token_to_index.items()))\n","print (f\"least freq token's freq: {tokenizer.min_token_freq}\") # use this to adjust num_tokens"],"execution_count":null,"outputs":[{"output_type":"stream","text":["[('<PAD>', 0), ('<UNK>', 1), ('39', 2), ('b', 3), ('gt', 4)]\n","least freq token's freq: 166\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"7CQYu3DFqVNR","executionInfo":{"status":"ok","timestamp":1608329403709,"user_tz":420,"elapsed":8471,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"d6c6baee-dbeb-4b82-85b9-682e65eb4764"},"source":["# Convert texts to sequences of indices\n","X_train = tokenizer.texts_to_sequences(X_train)\n","X_val = tokenizer.texts_to_sequences(X_val)\n","X_test = tokenizer.texts_to_sequences(X_test)\n","preprocessed_text = tokenizer.sequences_to_texts([X_train[0]])[0]\n","print (\"Text to indices:\\n\"\n","    f\"  (preprocessed) → {preprocessed_text}\\n\"\n","    f\"  (tokenized) → {X_train[0]}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Text to indices:\n","  (preprocessed) → china <UNK> north korea nuclear 
talks\n","  (tokenized) → [ 16   1 285 142 114  24]\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"X_BtTeojq4UU"},"source":["# One-hot encoding"]},{"cell_type":"markdown","metadata":{"id":"XgB24lwpq7aA"},"source":["One-hot encoding creates a binary column for each unique value of the feature we're trying to map. All of the values in each token's array will be 0 except at the index that represents this specific token.\n","\n","For example, suppose there are 5 tokens (characters) in our vocabulary:\n","```json\n","{\n","    \"a\": 0,\n","    \"e\": 1,\n","    \"i\": 2,\n","    \"o\": 3,\n","    \"u\": 4\n","}\n","```\n","\n","Then the text `aou` would be represented by:\n","```python\n","[[1. 0. 0. 0. 0.]\n"," [0. 0. 0. 1. 0.]\n"," [0. 0. 0. 0. 1.]]\n","```\n","\n","One-hot encoding allows us to represent our data in a way that our models can process and that isn't biased by the numerical value of a token's index (e.g. index 4 isn't treated as greater than index 1). \n","\n","> We have already applied one-hot encoding in the previous lessons when we encoded our labels. Each label was represented by a unique index, but when determining the loss, we effectively used its one-hot representation and compared it to the predicted probability distribution. We never explicitly wrote this out since all of our previous tasks were multi-class, which means every input had just one output class, so the 0s didn't affect the loss (though they did matter during backpropagation)."]},{"cell_type":"code","metadata":{"id":"tYqX76Lpq4aK"},"source":["def to_categorical(seq, num_classes):\n","    \"\"\"One-hot encode a sequence of tokens.\"\"\"\n","    one_hot = np.zeros((len(seq), num_classes))\n","    for i, item in enumerate(seq):\n","        one_hot[i, item] = 1.\n","    return one_hot"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Nze7Wq-sq4cv","executionInfo":{"status":"ok","timestamp":1608329403713,"user_tz":420,"elapsed":8433,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"cbb0ee0a-a054-4f01-d11a-f3d8b19e541b"},"source":["# One-hot encoding\n","print (X_train[0])\n","print (len(X_train[0]))\n","cat = to_categorical(seq=X_train[0], num_classes=len(tokenizer))\n","print (cat)\n","print (cat.shape)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["[ 16   1 285 142 114  24]\n","6\n","[[0. 0. 0. ... 0. 0. 0.]\n"," [0. 1. 0. ... 0. 0. 0.]\n"," [0. 0. 0. ... 0. 0. 0.]\n"," [0. 0. 0. ... 0. 0. 0.]\n"," [0. 0. 0. ... 0. 0. 0.]\n"," [0. 0. 0. ... 0. 0. 
0.]]\n","(6, 500)\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"FHAIYQLMq4fq"},"source":["# Convert tokens to one-hot\n","vocab_size = len(tokenizer)\n","X_train = [to_categorical(seq, num_classes=vocab_size) for seq in X_train]\n","X_val = [to_categorical(seq, num_classes=vocab_size) for seq in X_val]\n","X_test = [to_categorical(seq, num_classes=vocab_size) for seq in X_test]"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AvdiT_-6rnYF"},"source":["# Padding"]},{"cell_type":"markdown","metadata":{"id":"M384IKVPrnpI"},"source":["Our inputs are all of varying length, but we need each batch to be uniformly shaped. Therefore, we will use padding to make all the inputs in the batch the same length. We'll pad with zeros (note that the `<PAD>` token defined in our `Tokenizer` also uses index 0); since our inputs are already one-hot encoded, each padded position is simply an all-zero vector.\n","\n","> One-hot encoding creates a batch of shape (`N`, `max_seq_len`, `vocab_size`) so we'll need to be able to pad 3D sequences."]},{"cell_type":"code","metadata":{"id":"440sFfwBrnu8"},"source":["def pad_sequences(sequences, max_seq_len=0):\n","    \"\"\"Pad sequences to max length in sequence.\"\"\"\n","    max_seq_len = max(max_seq_len, max(len(sequence) for sequence in sequences))\n","    num_classes = sequences[0].shape[-1]\n","    padded_sequences = np.zeros((len(sequences), max_seq_len, num_classes))\n","    for i, sequence in enumerate(sequences):\n","        padded_sequences[i][:len(sequence)] = sequence\n","    return padded_sequences"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"NeQoMsyUrnyD","executionInfo":{"status":"ok","timestamp":1608329405580,"user_tz":420,"elapsed":10247,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"8aac87eb-f021-4774-8fa3-0f93747d881a"},"source":["# 3D 
sequences\n","print (X_train[0].shape, X_train[1].shape, X_train[2].shape)\n","padded = pad_sequences(X_train[0:3])\n","print (padded.shape)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["(6, 500) (5, 500) (6, 500)\n","(3, 6, 500)\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"ii6RhseQsMKD"},"source":["# Dataset"]},{"cell_type":"markdown","metadata":{"id":"Kd8PKjN8soc5"},"source":["We're going to place our data into a [`Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) and use a [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) to efficiently create batches for training and evaluation."]},{"cell_type":"code","metadata":{"id":"hNOtsJAgwv_4"},"source":["FILTER_SIZE = 1 # unigram"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"VMS8vwTqrn3O"},"source":["class Dataset(torch.utils.data.Dataset):\n","    def __init__(self, X, y, max_filter_size):\n","        self.X = X\n","        self.y = y\n","        self.max_filter_size = max_filter_size\n","\n","    def __len__(self):\n","        return len(self.y)\n","\n","    def __str__(self):\n","        return f\"<Dataset(N={len(self)})>\"\n","\n","    def __getitem__(self, index):\n","        X = self.X[index]\n","        y = self.y[index]\n","        return [X, y]\n","\n","    def collate_fn(self, batch):\n","        \"\"\"Processing on a batch.\"\"\"\n","        # Get inputs\n","        batch = np.array(batch, dtype=object)\n","        X = batch[:, 0]\n","        y = np.stack(batch[:, 1], axis=0)\n","\n","        # Pad sequences\n","        X = pad_sequences(X, max_seq_len=self.max_filter_size)\n","\n","        # Cast\n","        X = torch.FloatTensor(X.astype(np.int32))\n","        y = torch.LongTensor(y.astype(np.int32))\n","\n","        return X, y\n","\n","    def create_dataloader(self, batch_size, shuffle=False, drop_last=False):\n","        return torch.utils.data.DataLoader(\n","         
   dataset=self, batch_size=batch_size, collate_fn=self.collate_fn,\n","            shuffle=shuffle, drop_last=drop_last, pin_memory=True)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"tXHORNGWrn6L","executionInfo":{"status":"ok","timestamp":1608329405586,"user_tz":420,"elapsed":10188,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"ae6d9ea0-7992-4bb3-e5ee-eeb1184f27af"},"source":["# Create datasets for embedding\n","train_dataset = Dataset(X=X_train, y=y_train, max_filter_size=FILTER_SIZE)\n","val_dataset = Dataset(X=X_val, y=y_val, max_filter_size=FILTER_SIZE)\n","test_dataset = Dataset(X=X_test, y=y_test, max_filter_size=FILTER_SIZE)\n","print (\"Datasets:\\n\"\n","    f\"  Train dataset:{train_dataset.__str__()}\\n\"\n","    f\"  Val dataset: {val_dataset.__str__()}\\n\"\n","    f\"  Test dataset: {test_dataset.__str__()}\\n\"\n","    \"Sample point:\\n\"\n","    f\"  X: {test_dataset[0][0]}\\n\"\n","    f\"  y: {test_dataset[0][1]}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Datasets:\n","  Train dataset:<Dataset(N=84000)>\n","  Val dataset: <Dataset(N=18000)>\n","  Test dataset: <Dataset(N=18000)>\n","Sample point:\n","  X: [[0. 0. 0. ... 0. 0. 0.]\n"," [0. 1. 0. ... 0. 0. 0.]\n"," [0. 1. 0. ... 0. 0. 0.]\n"," [0. 1. 0. ... 0. 0. 
0.]]\n","  y: 1\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"77K7kzitqwdX","executionInfo":{"status":"ok","timestamp":1608329416454,"user_tz":420,"elapsed":21026,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"fab54836-13df-49ac-d93f-dbd04ee37245"},"source":["# Create dataloaders\n","batch_size = 64\n","train_dataloader = train_dataset.create_dataloader(batch_size=batch_size)\n","val_dataloader = val_dataset.create_dataloader(batch_size=batch_size)\n","test_dataloader = test_dataset.create_dataloader(batch_size=batch_size)\n","batch_X, batch_y = next(iter(test_dataloader))\n","print (\"Sample batch:\\n\"\n","    f\"  X: {list(batch_X.size())}\\n\"\n","    f\"  y: {list(batch_y.size())}\\n\"\n","    \"Sample point:\\n\"\n","    f\"  X: {batch_X[0]}\\n\"\n","    f\"  y: {batch_y[0]}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Sample batch:\n","  X: [64, 14, 500]\n","  y: [64]\n","Sample point:\n","  X: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n","        [0., 1., 0.,  ..., 0., 0., 0.],\n","        [0., 1., 0.,  ..., 0., 0., 0.],\n","        ...,\n","        [0., 0., 0.,  ..., 0., 0., 0.],\n","        [0., 0., 0.,  ..., 0., 0., 0.],\n","        [0., 0., 0.,  ..., 0., 0., 0.]], device='cpu')\n","  y: 1\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"Y5FMz_VkzvvG"},"source":["# CNN"]},{"cell_type":"markdown","metadata":{"id":"I6Cs6GYInnD8"},"source":["## Inputs"]},{"cell_type":"markdown","metadata":{"id":"sWf6pSvunpoC"},"source":["We're going to learn about CNNs by applying them on 1D text data. In the dummy example below, our inputs are composed of character tokens that are one-hot encoded. 
We have a batch of N samples, where each sample has 8 characters and each character is represented by an array of 10 values (`vocab size=10`). This gives our inputs the size `(N, 8, 10)`.\n","\n","> With PyTorch, when dealing with convolution, our inputs (X) need to have the channels as the second dimension, so our inputs will be `(N, 10, 8)`. "]},{"cell_type":"code","metadata":{"id":"c4AW9_QGpBqS"},"source":["import math\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"bCQeJEUVnnrR","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608329416461,"user_tz":420,"elapsed":20989,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"4c09110b-fb0e-45cd-ea0d-c93786a17e98"},"source":["# Assume all our inputs are padded to have the same # of words\n","batch_size = 64\n","max_seq_len = 8 # words per input\n","vocab_size = 10 # one hot size\n","x = torch.randn(batch_size, max_seq_len, vocab_size)\n","print(f\"X: {x.shape}\")\n","x = x.transpose(1, 2)\n","print(f\"X: {x.shape}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["X: torch.Size([64, 8, 10])\n","X: torch.Size([64, 10, 8])\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"UAJMhb7DsFgV"},"source":["<div align=\"left\">\n","<img src=\"https://raw.githubusercontent.com/GokuMohandas/MadeWithML/main/images/basics/cnn/inputs.png\" width=\"500\">\n","</div>\n","\n","This diagram above is for char-level tokens but extends to any level of tokenization (word-level in our case)."]},{"cell_type":"markdown","metadata":{"id":"3uLslqxFl_au"},"source":["## Filters"]},{"cell_type":"markdown","metadata":{"id":"JhhTVijAl-Yp"},"source":["At the core of CNNs are filters (aka weights, kernels, etc.) 
which convolve (slide) across our input to extract relevant features. The filters are initialized randomly but learn to pick up meaningful features from the input that help optimize the objective. The intuition is that each filter represents a feature, and we apply the same filter at every position of every input to detect that feature wherever it appears (feature extraction via parameter sharing). \n","\n","We can see convolution in the diagram below, where we simplified the filters and inputs to be 2D for ease of visualization. Also note that the values shown are 0s and 1s, but in reality they can be any floating point value."]},{"cell_type":"markdown","metadata":{"id":"du4gRM5htR9W"},"source":["<div align=\"left\">\n","<img src=\"https://raw.githubusercontent.com/GokuMohandas/MadeWithML/main/images/basics/cnn/convolution.gif\" width=\"500\">\n","</div>"]},{"cell_type":"markdown","metadata":{"id":"96PfAWzYsOEI"},"source":["Now let's return to our actual input `x`, where each sample has shape (8, 10) [`max_seq_len`, `vocab_size`], and convolve over it using filters. We will use 50 filters of size (1, 3) that have the same depth as the number of channels (`num_channels` = `vocab_size` = `one_hot_size` = 10). 
This gives our filters a shape of (3, 10, 50) [`kernel_size`, `vocab_size`, `num_filters`]. Note that PyTorch stores the same weights as `[num_filters, vocab_size, kernel_size]`, i.e. `[50, 10, 3]`.\n","\n","<div align=\"left\">\n","<img src=\"https://raw.githubusercontent.com/GokuMohandas/MadeWithML/main/images/basics/cnn/filters.png\" width=\"500\">\n","</div>"]},{"cell_type":"markdown","metadata":{"id":"26nX0cVX-Bwl"},"source":["* **stride**: the number of positions the filters move from one convolution operation to the next.\n","* **padding**: values (typically zeros) added to the edges of the input, usually so that the output has whole-number dimensions or matches the input's size."]},{"cell_type":"markdown","metadata":{"id":"q5DH85gdEdR_"},"source":["So far we've assumed a `stride` of 1 and `VALID` padding (no padding), but let's look at an example with a larger stride and at the difference between padding approaches.\n","\n","Padding types:\n","* **VALID**: no padding; the filters only use the \"valid\" values in the input. If the filter cannot reach all the input values (filters go left to right), the extra values on the right are dropped.\n","* **SAME**: adds padding evenly to the left and right sides of the input (any extra unit goes on the right) so that all values in the input are processed.\n","\n","<div align=\"left\">\n","<img src=\"https://raw.githubusercontent.com/GokuMohandas/MadeWithML/main/images/basics/cnn/padding.png\" width=\"500\">\n","</div>"]},{"cell_type":"markdown","metadata":{"id":"wN7QCnCWuwiG"},"source":["We're going to use the [Conv1d](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d) layer to process our inputs."]},{"cell_type":"code","metadata":{"id":"0bwK2BE6diB3","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608329416464,"user_tz":420,"elapsed":20967,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"02f074d8-95bb-415d-9e8b-43bce2a84e4a"},"source":["# Convolutional filters (VALID padding)\n","vocab_size = 10 
# one hot size\n","num_filters = 50 # num filters\n","filter_size = 3 # each filter spans 3 tokens\n","stride = 1\n","padding = 0 # valid padding (no padding)\n","conv1 = nn.Conv1d(in_channels=vocab_size, out_channels=num_filters, \n","                  kernel_size=filter_size, stride=stride, \n","                  padding=padding, padding_mode='zeros')\n","print(\"conv: {}\".format(conv1.weight.shape))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["conv: torch.Size([50, 10, 3])\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"HYxn8MejfGi0","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608329416467,"user_tz":420,"elapsed":20947,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"28db5639-ac9c-4882-840f-be851a066f94"},"source":["# Forward pass\n","z = conv1(x)\n","print (f\"z: {z.shape}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["z: torch.Size([64, 50, 6])\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"wFD9qBslxs4A"},"source":["<div align=\"left\">\n","<img src=\"https://raw.githubusercontent.com/GokuMohandas/MadeWithML/main/images/basics/cnn/conv.png\" width=\"700\">\n","</div>"]},{"cell_type":"markdown","metadata":{"id":"tcBbTPW6tZtr"},"source":["When we apply these filters to our inputs, we receive an output of shape (N, 50, 6). 
We get 50 for the output channel dim because we used 50 filters, and 6 for the conv outputs because:\n","\n","$W_2 = \\frac{W_1 - F + 2P}{S} + 1 = \\frac{8 - 3 + 2(0)}{1} + 1 = 6$\n","\n","$H_2 = \\frac{H_1 - F_H + 2P}{S} + 1 = \\frac{1 - 1 + 2(0)}{1} + 1 = 1$\n","\n","$D_2 = K = 50$\n","\n","where:\n","  * `W`: width of each input = 8\n","  * `H`: height of each input = 1\n","  * `D`: depth (# channels), with $D_1$ = 10\n","  * `F`: filter width = 3\n","  * `F_H`: filter height = 1\n","  * `K`: number of filters = 50\n","  * `P`: padding = 0\n","  * `S`: stride = 1"]},{"cell_type":"markdown","metadata":{"id":"NSE9tWhUKHPb"},"source":["Now we'll add padding so that the convolutional outputs have the same width as our inputs. The amount of `SAME` padding can be determined using the same equation: we want our output to have the same width as our input, so we solve for P:\n","\n","$ \\frac{W-F+2P}{S} + 1 = W $\n","\n","$ P = \\frac{S(W-1) - W + F}{2} $\n","\n","If $P$ is not a whole number, we round up (using `math.ceil`) and place the extra padding on the right side."]},{"cell_type":"code","metadata":{"id":"uXB9FwR6EkeA","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608329416470,"user_tz":420,"elapsed":20927,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"f7af9beb-2084-4857-f340-1d1ae8d3139b"},"source":["# Convolutional filters (SAME padding)\n","vocab_size = 10 # one hot size\n","num_filters = 50 # num filters\n","filter_size = 3 # each filter spans 3 tokens\n","stride = 1\n","conv = nn.Conv1d(in_channels=vocab_size, out_channels=num_filters, \n","                 kernel_size=filter_size, stride=stride)\n","print(\"conv: {}\".format(conv.weight.shape))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["conv: torch.Size([50, 10, 
3])\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"HQLr0RG9uIgK","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608329416472,"user_tz":420,"elapsed":20905,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"0574101c-1413-491b-dd80-6994378ab32a"},"source":["# `SAME` padding\n","padding_left = int((conv.stride[0]*(max_seq_len-1) - max_seq_len + filter_size)/2)\n","padding_right = int(math.ceil((conv.stride[0]*(max_seq_len-1) - max_seq_len + filter_size)/2))\n","print (f\"padding: {(padding_left, padding_right)}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["padding: (1, 1)\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"KCxWRzxAEmJW","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608329416475,"user_tz":420,"elapsed":20879,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"26fe7b05-2059-4599-98ad-0f666556a1d4"},"source":["# Forward pass\n","z = conv(F.pad(x, (padding_left, padding_right)))\n","print (f\"z: {z.shape}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["z: torch.Size([64, 50, 8])\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"FRD72rMHvWHN"},"source":["> We will explore larger dimensional convolution layers in subsequent lessons. For example, [Conv2D](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) is used with 3D inputs (images, char-level text, etc.) 
and [Conv3D](https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#torch.nn.Conv3d) is used for 4D inputs (videos, volumetric data, etc.)."]},{"cell_type":"markdown","metadata":{"id":"HpB8HSJNwLp-"},"source":["## Pooling"]},{"cell_type":"markdown","metadata":{"id":"WxRnbPy6wLuc"},"source":["The result of convolving filters on an input is a feature map. Because neighboring filter windows overlap, our feature map will contain a lot of redundant information. Pooling is a way to summarize a high-dimensional feature map into a lower-dimensional one for simpler downstream computation. The pooling operation (max, average, etc.) is applied over each receptive field. Below is an example of pooling where the outputs from a conv layer are `4X4` and we apply max pool filters of size `2X2` with a stride of 2.\n","\n","<div align=\"left\">\n","<img src=\"https://raw.githubusercontent.com/GokuMohandas/MadeWithML/main/images/basics/cnn/pooling.png\" width=\"500\">\n","</div>"]},{"cell_type":"markdown","metadata":{"id":"ok0EBECKc2QU"},"source":["$W_2 = \\frac{W_1 - F}{S} + 1 = \\frac{4 - 2}{2} + 1 = 2$\n","\n","$H_2 = \\frac{H_1 - F}{S} + 1 = \\frac{4 - 2}{2} + 1 = 2$\n","\n","$ D_2 = D_1 $\n","\n","where:\n","  * `W`: width of each input = 4\n","  * `H`: height of each input = 4\n","  * `D`: depth (# channels)\n","  * `F`: filter size = 2\n","  * `S`: stride = 2"]},{"cell_type":"markdown","metadata":{"id":"5ijJtky9QHeX"},"source":["In our use case, we only want the single max value from each filter's feature map, so we use 1D max pooling ([MaxPool1D](https://pytorch.org/docs/stable/generated/torch.nn.MaxPool1d.html#torch.nn.MaxPool1d), applied below through its functional form `F.max_pool1d`) with a pooling window equal to the full sequence length (`max_seq_len`), i.e. global max pooling.\n"]},{"cell_type":"code","metadata":{"id":"niptcsv2wUPA","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608329416477,"user_tz":420,"elapsed":20856,"user":{"displayName":"Goku 
Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"40ff5268-8688-4536-8c20-9a2355a009fa"},"source":["# Max pooling\n","pool_output = F.max_pool1d(z, z.size(2))\n","print(\"Size: {}\".format(pool_output.shape))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Size: torch.Size([64, 50, 1])\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"ccelFfH-s3ZY"},"source":["## Batch Normalization"]},{"cell_type":"markdown","metadata":{"id":"9f67F4o1HHQp"},"source":["The last topic we'll cover before constructing our model is [batch normalization](https://arxiv.org/abs/1502.03167). It's an operation that standardizes (mean=0, std=1) the activations from the previous layer, using statistics computed per channel over the batch (and sequence positions), and then applies a learnable scale and shift. Recall that we standardized our inputs in previous notebooks so our model could optimize quickly with larger learning rates. It's the same concept here, but we keep the values standardized throughout the forward pass to further aid optimization. 
"]},{"cell_type":"code","metadata":{"id":"owtCbYoZs82g","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608329416479,"user_tz":420,"elapsed":20834,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"c2399f3a-a21a-45a8-8ddf-873ed6d6163b"},"source":["# Batch normalization\n","batch_norm = nn.BatchNorm1d(num_features=num_filters)\n","z = batch_norm(conv(x)) # applied to activations (after conv layer & before pooling)\n","print (f\"z: {z.shape}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["z: torch.Size([64, 50, 6])\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"3ABUSm-vyaTG","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608329416483,"user_tz":420,"elapsed":20812,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"405e600d-e99f-46f1-89ad-cb713332270b"},"source":["# Mean and std before batchnorm\n","print (f\"mean: {torch.mean(conv(x)):.2f}, std: {torch.std(conv(x)):.2f}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["mean: 0.01, std: 0.57\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"fNzIKpJUyQqk","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608329416486,"user_tz":420,"elapsed":20789,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"fca1b7e5-a699-4e12-b3c3-3e288bac2afb"},"source":["# Mean and std after batchnorm\n","print (f\"mean: {torch.mean(z):.2f}, std: {torch.std(z):.2f}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["mean: 0.00, std: 
1.00\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"Hpo4QXTOtGV6"},"source":["# Modeling"]},{"cell_type":"markdown","metadata":{"id":"pfhjWZRD94hK"},"source":["## Model"]},{"cell_type":"markdown","metadata":{"id":"zVmJGm8m-KIz"},"source":["Let's visualize the model's forward pass.\n","\n","1. We'll first tokenize our inputs (`batch_size`, `max_seq_len`).\n","2. Then we'll one-hot encode our tokenized inputs (`batch_size`, `max_seq_len`, `vocab_size`).\n","3. We'll apply convolution via filters (`filter_size`, `vocab_size`, `num_filters`) followed by batch normalization. Our filters act as n-gram detectors over tokens (word-level tokens in our case).\n","4. We'll apply 1D global max pooling, which extracts the most relevant information from each feature map for making the decision.\n","5. We feed the pool outputs to a fully-connected (FC) layer (with dropout).\n","6. We use one more FC layer with softmax to derive class probabilities. "]},{"cell_type":"markdown","metadata":{"id":"3ilKr0yTgl6o"},"source":["<div align=\"left\">\n","<img src=\"https://raw.githubusercontent.com/GokuMohandas/MadeWithML/main/images/basics/cnn/model.png\" width=\"1000\">\n","</div>"]},{"cell_type":"code","metadata":{"id":"jKXRELvlwQHe"},"source":["NUM_FILTERS = 50\n","HIDDEN_DIM = 100\n","DROPOUT_P = 0.1"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"UPP5ROd69mXC"},"source":["class CNN(nn.Module):\n","    def __init__(self, vocab_size, num_filters, filter_size,\n","                 hidden_dim, dropout_p, num_classes):\n","        super(CNN, self).__init__()\n","        \n","        # Convolutional filters\n","        self.filter_size = filter_size\n","        self.conv = nn.Conv1d(\n","            in_channels=vocab_size, out_channels=num_filters, \n","            kernel_size=filter_size, stride=1, padding=0, padding_mode='zeros')\n","        self.batch_norm = nn.BatchNorm1d(num_features=num_filters)\n","\n","        # FC layers\n","        self.fc1 = 
nn.Linear(num_filters, hidden_dim)\n","        self.dropout = nn.Dropout(dropout_p)\n","        self.fc2 = nn.Linear(hidden_dim, num_classes)\n","\n","    def forward(self, inputs, channel_first=False, apply_softmax=False):\n","\n","        # Rearrange input so num_channels is in dim 1 (N, C, L)\n","        x_in, = inputs\n","        if not channel_first:\n","            x_in = x_in.transpose(1, 2)\n","\n","        # Padding for `SAME` padding\n","        max_seq_len = x_in.shape[2]\n","        padding_left = int((self.conv.stride[0]*(max_seq_len-1) - max_seq_len + self.filter_size)/2)\n","        padding_right = int(math.ceil((self.conv.stride[0]*(max_seq_len-1) - max_seq_len + self.filter_size)/2))\n","\n","        # Conv outputs (conv -> batch norm -> global max pool)\n","        z = self.conv(F.pad(x_in, (padding_left, padding_right)))\n","        z = self.batch_norm(z)\n","        z = F.max_pool1d(z, z.size(2)).squeeze(2)\n","\n","        # FC layer\n","        z = self.fc1(z)\n","        z = self.dropout(z)\n","        y_pred = self.fc2(z)\n","\n","        if apply_softmax:\n","            y_pred = F.softmax(y_pred, dim=1)\n","        return y_pred"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"wD4sRUS5_lwq","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608329416494,"user_tz":420,"elapsed":20741,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"a22a4bc0-7c3a-4228-b6ba-9e8852a0aa2d"},"source":["# Initialize model\n","model = CNN(vocab_size=VOCAB_SIZE, num_filters=NUM_FILTERS, filter_size=FILTER_SIZE,\n","            hidden_dim=HIDDEN_DIM, dropout_p=DROPOUT_P, num_classes=NUM_CLASSES)\n","model = model.to(device) # set device\n","print (model.named_parameters)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["<bound method Module.named_parameters of CNN(\n","  (conv): Conv1d(500, 50, kernel_size=(1,), 
stride=(1,))\n","  (batch_norm): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n","  (fc1): Linear(in_features=50, out_features=100, bias=True)\n","  (dropout): Dropout(p=0.1, inplace=False)\n","  (fc2): Linear(in_features=100, out_features=4, bias=True)\n",")>\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"s1L7-vxbgCup"},"source":["> We used `SAME` padding (w/ stride=1) which means that the conv outputs will have the same width (`max_seq_len`) as our inputs. The amount of padding differs for each batch based on the `max_seq_len` but you can calculate it by solving for P in the equation below.\n","\n","$ \\frac{W_1 - F + 2P}{S} + 1 = W_2 $\n","\n","$ \\frac{\\text{max_seq_len } - \\text{ filter_size } + 2P}{\\text{stride}} + 1 = \\text{max_seq_len} $\n","\n","$ P = \\frac{\\text{stride}(\\text{max_seq_len}-1) - \\text{max_seq_len} + \\text{filter_size}}{2} $\n","\n","If $P$ is not a whole number, we round up (using `math.ceil`) and place the extra padding on the right side."]},{"cell_type":"markdown","metadata":{"id":"tzZTZI-y2Tzy"},"source":["## Training"]},{"cell_type":"markdown","metadata":{"id":"SYw0JY9k2VPk"},"source":["Let's create the `Trainer` class that we'll use to facilitate training for our experiments. 
Notice that we're now moving the `train` function inside this class."]},{"cell_type":"code","metadata":{"id":"2Cd0KoTq3MhV"},"source":["from torch.optim import Adam"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"NPUmSaoR3PQc"},"source":["LEARNING_RATE = 1e-3\n","PATIENCE = 5\n","NUM_EPOCHS = 10"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"H3EqfumC2Ud4"},"source":["class Trainer(object):\n","    def __init__(self, model, device, loss_fn=None, optimizer=None, scheduler=None):\n","\n","        # Set params\n","        self.model = model\n","        self.device = device\n","        self.loss_fn = loss_fn\n","        self.optimizer = optimizer\n","        self.scheduler = scheduler\n","\n","    def train_step(self, dataloader):\n","        \"\"\"Train step.\"\"\"\n","        # Set model to train mode\n","        self.model.train()\n","        loss = 0.0\n","\n","        # Iterate over train batches\n","        for i, batch in enumerate(dataloader):\n","\n","            # Step\n","            batch = [item.to(self.device) for item in batch]  # Set device\n","            inputs, targets = batch[:-1], batch[-1]\n","            self.optimizer.zero_grad()  # Reset gradients\n","            z = self.model(inputs)  # Forward pass\n","            J = self.loss_fn(z, targets)  # Define loss\n","            J.backward()  # Backward pass\n","            self.optimizer.step()  # Update weights\n","\n","            # Cumulative Metrics\n","            loss += (J.detach().item() - loss) / (i + 1)\n","\n","        return loss\n","\n","    def eval_step(self, dataloader):\n","        \"\"\"Validation or test step.\"\"\"\n","        # Set model to eval mode\n","        self.model.eval()\n","        loss = 0.0\n","        y_trues, y_probs = [], []\n","\n","        # Iterate over val batches\n","        with torch.no_grad():\n","            for i, batch in enumerate(dataloader):\n","\n","                # Step\n","              
  batch = [item.to(self.device) for item in batch]  # Set device\n","                inputs, y_true = batch[:-1], batch[-1]\n","                z = self.model(inputs)  # Forward pass\n","                J = self.loss_fn(z, y_true).item()\n","\n","                # Cumulative Metrics\n","                loss += (J - loss) / (i + 1)\n","\n","                # Store outputs (softmax over classes for multiclass probabilities)\n","                y_prob = F.softmax(z, dim=1).cpu().numpy()\n","                y_probs.extend(y_prob)\n","                y_trues.extend(y_true.cpu().numpy())\n","\n","        return loss, np.vstack(y_trues), np.vstack(y_probs)\n","\n","    def predict_step(self, dataloader):\n","        \"\"\"Prediction step.\"\"\"\n","        # Set model to eval mode\n","        self.model.eval()\n","        y_probs = []\n","\n","        # Iterate over val batches\n","        with torch.no_grad():\n","            for i, batch in enumerate(dataloader):\n","\n","                # Forward pass w/ inputs\n","                batch = [item.to(self.device) for item in batch]  # Set device\n","                inputs, targets = batch[:-1], batch[-1]\n","                y_prob = self.model(inputs, apply_softmax=True).cpu().numpy()\n","\n","                # Store outputs\n","                y_probs.extend(y_prob)\n","\n","        return np.vstack(y_probs)\n","    \n","    def train(self, num_epochs, patience, train_dataloader, val_dataloader):\n","        best_val_loss = np.inf\n","        for epoch in range(num_epochs):\n","            # Steps\n","            train_loss = self.train_step(dataloader=train_dataloader)\n","            val_loss, _, _ = self.eval_step(dataloader=val_dataloader)\n","            self.scheduler.step(val_loss)\n","\n","            # Early stopping\n","            if val_loss < best_val_loss:\n","                best_val_loss = val_loss\n","                best_model = self.model\n","                _patience = patience  # reset _patience\n","            else:\n","                _patience -= 1\n","            if not _patience:  # 0\n","                print(\"Stopping early!\")\n","          
      break\n","\n","            # Logging\n","            print(\n","                f\"Epoch: {epoch+1} | \"\n","                f\"train_loss: {train_loss:.5f}, \"\n","                f\"val_loss: {val_loss:.5f}, \"\n","                f\"lr: {self.optimizer.param_groups[0]['lr']:.2E}, \"\n","                f\"_patience: {_patience}\"\n","            )\n","        return best_model"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"SMfTLWuB290z"},"source":["# Define Loss\n","class_weights_tensor = torch.Tensor(list(class_weights.values())).to(device)\n","loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"wAjAGYpX3FG5"},"source":["# Define optimizer & scheduler\n","optimizer = Adam(model.parameters(), lr=LEARNING_RATE) \n","scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n","    optimizer, mode='min', factor=0.1, patience=3)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"GoA55gCt3FJk"},"source":["# Trainer module\n","trainer = Trainer(\n","    model=model, device=device, loss_fn=loss_fn, \n","    optimizer=optimizer, scheduler=scheduler)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ub6AB1qB3JCm","executionInfo":{"status":"ok","timestamp":1608329481433,"user_tz":420,"elapsed":85574,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"7763483e-d152-4530-ccf3-199b05bb54a0"},"source":["# Train\n","best_model = trainer.train(\n","    NUM_EPOCHS, PATIENCE, train_dataloader, val_dataloader)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Epoch: 1 | train_loss: 0.87388, val_loss: 0.79013, lr: 1.00E-03, _patience: 5\n","Epoch: 2 | train_loss: 0.78354, val_loss: 0.78657, lr: 1.00E-03, _patience: 
5\n","Epoch: 3 | train_loss: 0.77743, val_loss: 0.78433, lr: 1.00E-03, _patience: 5\n","Epoch: 4 | train_loss: 0.77242, val_loss: 0.78260, lr: 1.00E-03, _patience: 5\n","Epoch: 5 | train_loss: 0.76900, val_loss: 0.78169, lr: 1.00E-03, _patience: 5\n","Epoch: 6 | train_loss: 0.76613, val_loss: 0.78064, lr: 1.00E-03, _patience: 5\n","Epoch: 7 | train_loss: 0.76413, val_loss: 0.78019, lr: 1.00E-03, _patience: 5\n","Epoch: 8 | train_loss: 0.76215, val_loss: 0.78016, lr: 1.00E-03, _patience: 5\n","Epoch: 9 | train_loss: 0.76034, val_loss: 0.77974, lr: 1.00E-03, _patience: 5\n","Epoch: 10 | train_loss: 0.75859, val_loss: 0.77978, lr: 1.00E-03, _patience: 4\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"fG6ejmWj4DhH"},"source":["## Evaluation"]},{"cell_type":"code","metadata":{"id":"8k2WdDeH3S6P"},"source":["import json\n","from pathlib import Path\n","from sklearn.metrics import precision_recall_fscore_support"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Sne12LQZ4HTx"},"source":["def get_performance(y_true, y_pred, classes):\n","    \"\"\"Per-class performance metrics.\"\"\"\n","    # Performance\n","    performance = {\"overall\": {}, \"class\": {}}\n","\n","    # Overall performance\n","    metrics = precision_recall_fscore_support(y_true, y_pred, average=\"weighted\")\n","    performance[\"overall\"][\"precision\"] = metrics[0]\n","    performance[\"overall\"][\"recall\"] = metrics[1]\n","    performance[\"overall\"][\"f1\"] = metrics[2]\n","    performance[\"overall\"][\"num_samples\"] = np.float64(len(y_true))\n","\n","    # Per-class performance\n","    metrics = precision_recall_fscore_support(y_true, y_pred, average=None)\n","    for i in range(len(classes)):\n","        performance[\"class\"][classes[i]] = {\n","            \"precision\": metrics[0][i],\n","            \"recall\": metrics[1][i],\n","            \"f1\": metrics[2][i],\n","            \"num_samples\": np.float64(metrics[3][i]),\n","        
}\n","\n","    return performance"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"TXRbjnaz4I1H"},"source":["# Get predictions\n","test_loss, y_true, y_prob = trainer.eval_step(dataloader=test_dataloader)\n","y_pred = np.argmax(y_prob, axis=1)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"b3fVoGmZ5WRR","executionInfo":{"status":"ok","timestamp":1608329482415,"user_tz":420,"elapsed":86486,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"925933ec-16e1-4dec-914c-4ce5307a1cab"},"source":["# Determine performance\n","performance = get_performance(\n","    y_true=y_test, y_pred=y_pred, classes=label_encoder.classes)\n","print (json.dumps(performance['overall'], indent=2))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["{\n","  \"precision\": 0.7120047175492572,\n","  \"recall\": 0.6935,\n","  \"f1\": 0.6931471439737603,\n","  \"num_samples\": 18000.0\n","}\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"wIgO8aQd46p1"},"source":["# Save artifacts\n","dir = Path(\"cnn\")\n","dir.mkdir(parents=True, exist_ok=True)\n","label_encoder.save(fp=Path(dir, 'label_encoder.json'))\n","tokenizer.save(fp=Path(dir, 'tokenizer.json'))\n","torch.save(best_model.state_dict(), Path(dir, 'model.pt'))\n","with open(Path(dir, 'performance.json'), \"w\") as fp:\n","    json.dump(performance, indent=2, sort_keys=False, fp=fp)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"9SYtJtBB42jA"},"source":["## Inference"]},{"cell_type":"code","metadata":{"id":"u0pGgxCN9W9I"},"source":["def get_probability_distribution(y_prob, classes):\n","    \"\"\"Create a dict of class probabilities from an array.\"\"\"\n","    results = {}\n","    for i, class_ in enumerate(classes):\n","        results[class_] = 
np.float64(y_prob[i])\n","    sorted_results = {k: v for k, v in sorted(\n","        results.items(), key=lambda item: item[1], reverse=True)}\n","    return sorted_results"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"WciQjoMf4KZL","executionInfo":{"status":"ok","timestamp":1608329482420,"user_tz":420,"elapsed":86434,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"48942e17-41b4-47f9-989e-a100c28c4a8d"},"source":["# Load artifacts\n","device = torch.device(\"cpu\")\n","label_encoder = LabelEncoder.load(fp=Path(dir, 'label_encoder.json'))\n","tokenizer = Tokenizer.load(fp=Path(dir, 'tokenizer.json'))\n","model = CNN(\n","    vocab_size=VOCAB_SIZE, num_filters=NUM_FILTERS, filter_size=FILTER_SIZE,\n","    hidden_dim=HIDDEN_DIM, dropout_p=DROPOUT_P, num_classes=NUM_CLASSES)\n","model.load_state_dict(torch.load(Path(dir, 'model.pt'), map_location=device))\n","model.to(device)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["CNN(\n","  (conv): Conv1d(500, 50, kernel_size=(1,), stride=(1,))\n","  (batch_norm): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n","  (fc1): Linear(in_features=50, out_features=100, bias=True)\n","  (dropout): Dropout(p=0.1, inplace=False)\n","  (fc2): Linear(in_features=100, out_features=4, bias=True)\n",")"]},"metadata":{"tags":[]},"execution_count":63}]},{"cell_type":"code","metadata":{"id":"5uioiNWd5akm"},"source":["# Initialize trainer\n","trainer = Trainer(model=model, device=device)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"rKHHNuw65dPx","executionInfo":{"status":"ok","timestamp":1608329482425,"user_tz":420,"elapsed":86395,"user":{"displayName":"Goku 
Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"96ea1657-6aab-405f-f085-70d80cf7a92d"},"source":["# Dataloader\n","text = \"What a day for the new york stock market to go bust!\"\n","sequences = tokenizer.texts_to_sequences([preprocess(text)])\n","print (tokenizer.sequences_to_texts(sequences))\n","X = [to_categorical(seq, num_classes=len(tokenizer)) for seq in sequences]\n","y_filler = label_encoder.encode([label_encoder.classes[0]]*len(X))\n","dataset = Dataset(X=X, y=y_filler, max_filter_size=FILTER_SIZE)\n","dataloader = dataset.create_dataloader(batch_size=batch_size)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["['day new <UNK> stock market go <UNK>']\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"QcZnAxrm5dSY","executionInfo":{"status":"ok","timestamp":1608329482426,"user_tz":420,"elapsed":86367,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"0be1317f-1065-4340-d8cf-f020075bd7a9"},"source":["# Inference\n","y_prob = trainer.predict_step(dataloader)\n","y_pred = np.argmax(y_prob, axis=1)\n","label_encoder.decode(y_pred)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["['Business']"]},"metadata":{"tags":[]},"execution_count":66}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"WhjciE436mYQ","executionInfo":{"status":"ok","timestamp":1608329482427,"user_tz":420,"elapsed":86340,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"04a9bc66-9692-4fac-ecc0-991a84686c55"},"source":["# Class distributions\n","prob_dist = 
get_probability_distribution(y_prob=y_prob[0], classes=label_encoder.classes)\n","print (json.dumps(prob_dist, indent=2))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["{\n","  \"Business\": 0.8670833110809326,\n","  \"Sci/Tech\": 0.10699427127838135,\n","  \"World\": 0.021050667390227318,\n","  \"Sports\": 0.004871787969022989\n","}\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"1TsxDKyRArrQ"},"source":["# Interpretability"]},{"cell_type":"markdown","metadata":{"id":"wqa6OxyeAtkj"},"source":["We went through the trouble of padding our inputs before convolution so that the conv outputs have the same length as our inputs, which lets us try to get some interpretability. Since every token is mapped to a convolutional output on which we apply max pooling, we can see which token's output was most influential towards the prediction. We first need to get the conv outputs from our model:"]},{"cell_type":"code","metadata":{"id":"T-TywFzL54LS"},"source":["import collections\n","import seaborn as sns"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"-I6gSfZ9BSOD"},"source":["class InterpretableCNN(nn.Module):\n","    def __init__(self, vocab_size, num_filters, filter_size,\n","                 hidden_dim, dropout_p, num_classes):\n","        super(InterpretableCNN, self).__init__()\n","        \n","        # Convolutional filters\n","        self.filter_size = filter_size\n","        self.conv = nn.Conv1d(\n","            in_channels=vocab_size, out_channels=num_filters, \n","            kernel_size=filter_size, stride=1, padding=0, padding_mode='zeros')\n","        self.batch_norm = nn.BatchNorm1d(num_features=num_filters)\n","\n","        # FC layers\n","        self.fc1 = nn.Linear(num_filters, hidden_dim)\n","        self.dropout = nn.Dropout(dropout_p)\n","        self.fc2 = nn.Linear(hidden_dim, num_classes)\n","\n","    def forward(self, inputs, channel_first=False, apply_softmax=False):\n","\n","        # 
Rearrange input so num_channels is in dim 1 (N, C, L)\n","        x_in, = inputs\n","        if not channel_first:\n","            x_in = x_in.transpose(1, 2)\n","\n","        # Padding for `SAME` padding\n","        max_seq_len = x_in.shape[2]\n","        padding_left = int((self.conv.stride[0]*(max_seq_len-1) - max_seq_len + self.filter_size)/2)\n","        padding_right = int(math.ceil((self.conv.stride[0]*(max_seq_len-1) - max_seq_len + self.filter_size)/2))\n","\n","        # Conv outputs\n","        z = self.conv(F.pad(x_in, (padding_left, padding_right)))\n","        \n","        return z"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"kdgTQI3qBwVe"},"source":["# Initialize\n","interpretable_model = InterpretableCNN(\n","    vocab_size=len(tokenizer), num_filters=NUM_FILTERS, filter_size=FILTER_SIZE,\n","    hidden_dim=HIDDEN_DIM, dropout_p=DROPOUT_P, num_classes=NUM_CLASSES)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"iF8ckOT8BSQ5","executionInfo":{"status":"ok","timestamp":1608329482431,"user_tz":420,"elapsed":86287,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"cab4e23e-4634-4081-acc3-15cb3a9402ea"},"source":["# Load weights (same architecture)\n","interpretable_model.load_state_dict(torch.load(Path(dir, 'model.pt'), map_location=device))\n","interpretable_model.to(device)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["InterpretableCNN(\n","  (conv): Conv1d(500, 50, kernel_size=(1,), stride=(1,))\n","  (batch_norm): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n","  (fc1): Linear(in_features=50, out_features=100, bias=True)\n","  (dropout): Dropout(p=0.1, inplace=False)\n","  (fc2): Linear(in_features=100, out_features=4, 
bias=True)\n",")"]},"metadata":{"tags":[]},"execution_count":70}]},{"cell_type":"code","metadata":{"id":"kFDPbdHLCJpa"},"source":["# Initialize trainer\n","interpretable_trainer = Trainer(model=interpretable_model, device=device)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"defuFVDYBSTn","executionInfo":{"status":"ok","timestamp":1608329513133,"user_tz":420,"elapsed":609,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"03d0a31e-d853-42ff-bd7b-fe42282f7d12"},"source":["# Get conv outputs\n","conv_outputs = interpretable_trainer.predict_step(dataloader)\n","print (conv_outputs.shape) # (num_filters, max_seq_len)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["(50, 7)\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":284},"id":"S8BeyVdk6Mzd","executionInfo":{"status":"ok","timestamp":1608329546685,"user_tz":420,"elapsed":685,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"cca94fd3-d355-4ac1-fdc6-0a573bbf90ec"},"source":["# Visualize every filter's output for each token (rows: filters, columns: tokens)\n","tokens = tokenizer.sequences_to_texts(sequences)[0].split(' ')\n","sns.heatmap(conv_outputs, xticklabels=tokens)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["<matplotlib.axes._subplots.AxesSubplot at 
0x7f832b27eb70>"]},"metadata":{"tags":[]},"execution_count":75},{"output_type":"display_data","data":{"image/png":"iVBORw0KGgoAAAANSUhEUgAAAWYAAAD6CAYAAACS9e2aAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3de5xVdbnH8c+XqyJyS0UUTFJMQwsVMTt5OWmGlWLmLTUlIbocO92sLC3TrDTzeKxTR/FKWXaUUvGGIqCoaYLmBRTEK6CggqIicpmZ5/yx1tBmWHvW2jNr7/WbPc+b13qxZq/bMzNrnv3bv/Vbz5KZ4ZxzLhxdig7AOefcxjwxO+dcYDwxO+dcYDwxO+dcYDwxO+dcYDwxO+dcYDwxO+dcYLqlrSBpV2AMsH380svAFDN7OssBrt3upCAHSr8b6FvSFk1FR9DxrFfREST71PtfLjqERDNe3D59pYKc9Mq17f5trl/+fOac032rDwR59rSaniT9APgLIODheBJwnaQzqh+ec851Pmkt5nHAcDNbX/qipP8C5gHnVysw55xrk6bGoiNot7QP9E3AdgmvD4qXJZI0QdIcSXNmrF7Ynvicc64yjQ3Zp0CltZi/BUyXtBBYHL+2A7AzcFq5jcxsIjAR4PLBJ9l7OQSat1D7JQMNC4BtGtenr1SAj37+7aJDSDTpljD7cvsXHUCVmXX8CzWtJmYzmyppF2AUG1/8m21mHf/zgsss1KTs3Caa6jwxA1j09vNQDWJxzrn2q/cWs3POdTh1cPHPE7Nzrr54izld76Yg7y+hsWuYl9m6Bfrggv1v/0LRIZRlS58vOoRE2/1tTtEhJFqvMM/9vFjAoy2y8hazc66+dIaLf84516F4V4ZzzgWmM1z8kzQKMDObLelDwGhgvpndnuUAu/cMc/C/FGZf7hYD1hYdQqJnj7mi6BDK6hLo73L0dwcWHUKiV69dnL5SR1bvLWZJZwOHAd0kTQP2BWYCZ0ja08x+XoMYnXMuuxwv/kkaDVwCdAWuMLPzWyz/DjAeaABeB041s5fae9y0FvPRwAigJ7AMGGxmb0v6NfAPwBOzcy4sOV38k9QV+B3wSWAJMFvSFDN7qmS1fwIjzWy1pK8BvwKOa++x04oYNZhZo5mtBp4zs7cBzOw9MhYxmvxOu988nHMuM7PGzFOKUcCzZva8ma0jKoE8ZuNj2cw4P0J0h/TgPL6HtBbzOkm94gPv3fyipL60kphLixhdssNJNiPAvvjVgQ7lbFgZZn8pwPdvObnoEBJN/Mwfiw4h0T4Xryw6hETTum9TdAhl/TiPnVTQxyxpAjCh5KWJcf6CqD5QaYf8EqLu3HLGAXdkPngr0hLzAWa2FjbUzGjWHTgljwBcxxBqUnZuExV0ZZQ2IttD0knASODA9u4L0qvLJQ4RMLPlwPI8AnDOuVzlNyrjZWBIydeD49c2IukQ4EzgwHI5s1I+jtk5V1/yK1E7GxgmaShRQj4eOKF0BUl7ApcBo83stbwO7InZOVdfchqVYWYNkk4D7iQaLneVmc2TdC4wx8ymABcCvYEbFNUgWWRmR7T32LIqF8157+rvB3k160fnLCo6hERTV4dZkOeUzXcpOoSy1hPkKcZc3i06hERPvPdK0SGU9fRrD7f7svyaB6/LfEJstt8XghwG4C1m51x98SJGzjkXGE/MzjkXFquD51NWPTG/MTHMYuHf3irIriV+1H/zokNI9NbLrxcdQlndNwvwDiYCjmuLMM+x3NRBEaNWb8mWtK+kPvH85pLOkXSLpAviu/+ccy4sTU3Zp0Cl1cq4Cmi+D/wSoC9wQfza1VWMyznn2saask+BSuvK6GJmzTX0RprZXvH8/
ZIeK7dR6f3n5w/9ICcN3K79kTrnXBYBt4SzSkvMcyV9ycyuBh6XNNLM5kjaBSjbw156//m1251kMwMsMHfU6WH2s039VZhjXwE+PnyTu1GDcP+87YsOIdGBe4f587r3kTB/XgBH5bGTgFvCWaUl5vHAJZLOIqqN8aCkxUQVl8ZXOzgXjlCTsnObaKjzp2Sb2VvA2PgC4NB4/SVm9motgnPOuYp1ghYzAHGB/MerHItzzrVfJ+hjds65jqWztJjbY/fuYT4l+9eXhFn4ZvtuaSMYi9Hnwu8WHUJZhz16T9EhJPr9uWH+LvvVe3PMW8zOORcYbzE751xg6n1UhnPOdThVrjFfC1VPzK+vCfNGjs/1CrPv+0b6FB1ComWn/rboEMra+jP9ig4h0VH9w3xK9h9XhvuU7Fx0hj5mSR8guiFnCNAIPAP8OR5C55xzYamDxJxWXe4/gUuBzYB9gJ5ECfohSQdVPTrnnKtUHRQxShvP82XgMDM7DzgEGG5mZwKjgYvLbSRpgqQ5kubc9t5z+UXrnHNpGhuzT4HK0sfcjagLoyfR02Axs0WSupfboLSI0V0Djw+yK35GY5jlpBd1WVN0CInOW9mX3/xgUNFhJGp8+oWiQ0gUal/uIq0tOoTqqveuDOAKYLaky4EHgd8BSNoaeKPKsbmAhJqUndtEjoXyJY2WtEDSs5LOSFh+gKRHJTVIOjqvbyGtiNElku4GdgMuMrP58euvAwfkFYRzzuUmp75jSV2JGqOfBJYQNVKnmNlTJastAsYCp+dy0FhqV4aZzQPm5XlQ55yrFmvKrfN0FPCsmT0PIOkvwBhgQ2I2sxfjZbn2n/gNJs65+pJfH/P2RLXnmy0B9s1r562pemJ+q0vXah+iTcad+F7RISS6eVKYN+RMP2d50SGU9VaXMG8w+drwxekrFeCueUOKDqG6KhhtUfoYvNjEePBCobzF7JyrLxW0mEtHkCV4mei+jWaD49eqzhOzc66+5NeVMRsYJmkoUUI+Hjghr523JsyCsc4511Zm2adWd2MNwGnAncDTwPVmNk/SuZKOAJC0j6QlwDHAZZJyGShR9RbzS91V7UO0yTXXhtmXS5hd8nz+plyeX1wdCrN9sf7Ky4sOIdGb84uOoMpyvMHEzG4Hbm/x2k9K5mcTdXHkyrsynHP1Jb/hcoXxxOycqy8B18DIKq26XB9Jv5T0R0kntFj2+1a221DE6KFVC/OK1TnnUllTU+YpVGkt5quBhcBfgVMlfR44wczWAh8tt1HpEJRXPvbvVqMRJhW598Xtig4h0WvdwuyTv+zom5lwedlfeaHssUeKDiHRGw+H+YijxjBPsfx0gq6Mnczs8/H8TZLOBGY0X5F0nUeoSdm5TQRcZzmrtMTcU1IXs+g7NbOfS3oZmEVcAtQ554JSBy3mtHFGtwCfKH3BzK4Bvgusq1JMzjnXdg2N2adApZX9/H6Z16dK+kV1QnLOuXboBF0ZrTmH6OJgq5a/EmaPxwcIs4jR+9eHeWVm+Vk3Fh1CWT22CLPls3n/MD9S77M0zHM/N3XQldFqYpb0RLlFwMD8w3HOufYJeRhcVmkt5oHAp4A3W7wu4O9Vicg559qj3lvMwK1AbzN7rOUCSfdUJSLnnGuPek/MZjaulWWZyt/NbgjzadTHnfBO0SEkuuG6LYsOIdEQVhYdQll/fS73GjK5OGrokqJDSLSgS6AFvICP5bGTOrgl22tlOOfqSo7P/CuMJ2bnXH2pg8ScVsRodMl8X0lXSnpC0p8llR2VUVrE6N53vYiRc66GmpqyT4FKazH/Apgaz18ELAUOB44CLgOOTNqotIjRz95/or1CeO9gP5+8RdEhJHqp27tFh5Bo2mt9GLumZ9FhJFodZlisWx3mUw+mBXqOAXwpj53UQYu5kq6MkWY2Ip6/WNIp1QjIhSnUpOzcJjpBYt5G0neIxi33kSSzDQ/KCvN5Ps65Ts0aw+2iyCotMV8ONI/fmgRsBbwuaVtgk
7HNzjlXuHpvMZvZOWVeXyZpZnVCcs65tuvsw+UyFTE6rteKdhyielav7lF0CIk+uXJB0SEkuvKizxQdQlmrz1pWdAiJthgY5hNMpi18qugQqqveE7MXMXLOdTg5djHHQ4YvAboCV5jZ+S2W9wT+AOwNrACOM7MX23tcL2LknKsr1pBPZpbUFfgd8ElgCTBb0hQzK/3IMQ5408x2lnQ8cAFwXHuP7UWMnHP1Jb8W8yjgWTN7HkDSX4AxQGliHgP8NJ6fDPxPi9FrbVL1IkYN68IcZN+jW5iFTmZts1PRISRa9Ouniw6hrN0HhDk8as0bYZ77oZ5jeank4p+kCcCEkpcmxjfIAWwPLC5ZtgTYt8UuNqxjZg2S3gLeByyvMOyNeK0M51x9qeB9uvQu5ZB4YnbO1ZUch8u9DAwp+Xpw/FrSOkskdQP6El0EbJe0IkYjJc2UdK2kIZKmSXpL0mxJe7ay3YYiRte/vai9MTrnXHZNFUytmw0MkzRUUg/geGBKi3WmAM3lKY4GZrS3fxnSW8y/B84G+hGNwvi2mX1S0sHxsv2SNir9eHDbwC/YC6vaG2b+euQ5piZnP+zyStEhJDrbwixIv1ZhVgeYvTrcc2z62jCL+M/JYR+W0/DxuM/4NOBOouFyV5nZPEnnAnPMbApwJfBHSc8CbxAl73ZLS8zdzewOAEkXmNnkOODpkn6dRwBuY56UXbWFmpTzYjm+H5rZ7cDtLV77Scn8GuCY/I4YSUvMayQdStRvYpKONLObJB0IhDmswTnXuYX7QSWztMT8VeBXRN/qp4CvSbqGqMP7y9UNzTnnKpdni7korXbOmdnjZvYpMzvMzOab2TfNrJ+ZDQc+WKMYnXMuM2vKPoWq6kWMugf49BKAdxTm4P8fsEPRIST66B7h9kv2/nEuz73I3V7/8eeiQ0g0cm2Y51herFFFh9BuXsTIOVdXQm4JZ+VFjJxzdcWa6rzFjBcxcs51MHXfYs6jiNGKrmH25a7qEua7api3SsCUp4fQtzHM6wXdjr2r6BASbWYDig4h0dvdwzz382LW8b8/r5XhMgk1KTvXUj20mNvcQJN0R56BOOdcHpoalXkKVdqojL3KLQJGtLLdhhqn4/qO4uBeO7c5QOecq0RnuPg3G7iXKBG31K/cRqVFjK7b7kT/DOycq5nOkJifBr5iZgtbLpC0OGH9TbzZNcwf0rbrw+yImtszzJ9Xn4Aro7zbJcxLph/fo2Xp3jA8PL++C1K1v+hm8dIS808p3w/9jXxDcc659quHFnNarYzJgCQdLKl3i8VrqheWc861jZkyT6FKe4LJfwI3E7WO50oaU7L4F9UMzDnn2qKxUZmnUKV1ZXwZ2NvMVknaEZgsaUczu4TkC4KbCLUvd9+hy4oOIdGxjywoOoRE92/V8uHA4Rg8ZGXRISRatnDLokNIdO7Se4oOoayfpK+SKuSWcFZpibmLma0CMLMXJR1ElJzfT8bE7JxztVT3fczAq5I2jFeOk/Rnga2APaoZmHPOtYVZ9ilUaS3mk4GNHm1oZg3AyZIuq1pUzjnXRvXQYk4rYlS2OrqZPZDlAKPPCLOf7cFfhjkw9/Bttyg6hEQXsJ5Ld3q76DASPfHUtkWHkKiXQj3Hyt3QWx8am8Ic116Jir8DSdtUIxAXtlCTsnMt1UNXRtpwuQEtpvcBD0vqLynMmobOuU6tyZR5ao84J06TtDD+v3+Z9aZKWinp1qz7TmsxLwceKZnmANsDj8bz5QKeIGmOpDlX/f3prLE451y71fAGkzOA6WY2DJgef53kQuCLlew4LTF/D1gAHGFmQ81sKLAknv9AuY3MbKKZjTSzkad+bLdK4nHOuXapYVfGGGBSPD8JODI5HpsOvFPJjtMu/l0k6f+Ai+OiRWdDZY+9Xnn9M5WsXjP3bbZd0SEkWh9ole9+l/yo6BDKeuqzfyw6hETPBXoNan1TmOdYXirpoigtURybGFfHzGKgmS2N55eR4wOqU59gE
o/MOEbSEcA0oFdeB3fOubxVMiqjtERxEkl3A0nDfs5ssR+TlNvlxNTELGlXon7lGUSJeaf49dFmNjWvQJxzLg95DrYws0PKLZP0qqRBZrZU0iDgtbyOW1ERI+BQM5sbL/YiRs654NRqVAYwBTglnj+FKFfmoupFjLp0C3Ow4Ii1YcZ1qDYrOoRET33290WHUNY+68N8pvA+RQdQVpjnWF5qWMTofOB6SeOAl4BjASSNBL5qZuPjr+8DdgV6S1oCjDOzO1vbsRcxcs7VlVpd2jSzFcDBCa/PAcaXfL1/pfv2IkbOubpiKPMUKi9i5JyrKw31Xo85jyJGK5eFObquR6Djhb/VZXnRISRrgK9pSNFRJOoa6Hjh3o1hnmPndwnzIbEAD+ewj5BbwlmFedXEBSfUpOxcS2G+HVYmbbjco5LOkrRTrQJyzrn2qIc+5rQPgf2BfsBMSQ9L+rak1HuZS4sYXf/WolwCdc65LJoqmEKVlpjfNLPTzWwH4LvAMOBRSTPje8wTlRYxOrbvDnnG65xzrWpEmadQZe5jNrP7gPskfQP4JHAcrdxj3qz/dqvbHl0VbbO6R9EhJDp7fZjFlfb76OKiQygv0It/tq7oCJL1nBPmOZaXOniyVGpi3qQ0nJk1AlPjyTnngtIUcEs4q1bbGmZ2vKRdJR0sqXfpMkmjqxuac85VziqYQpU2KuMblBQxkjSmZLEXMXLOBaceLv6ldWVMoJ1FjF57KcynZD9pvdNXKkDvLmG+jzetKTqC8gK9V4jbngpz7PcWgZ5jeWlSx+/K8CJGzrm60lh0ADnwIkbOubrSpOxTqLyIkXOurtTDqIyqFzFasKZPpTHVxGb5PZ4rV+92CfOkunXuEIY0rC86jERrAh3IfOhOYRYLuuul7YsOoarC/MuujBcxcpmEmpSdaynkLoqsWk3MkroB44DPAc23C71MNITuSjPzv1bnXFACHaRTkbTPgH8ERgA/BT4dT+cAHwGuLbdRaRGjaaufzSlU55xL16jsU6jSujL2NrNdWry2BHhI0ia3azczs4nEdTQmDzqxHrp8nHMdRD20mNMS8xuSjgH+ahYN45fUBTgGeDPTASzMvNyrKcxf32aBnlb7XbpX0SGU1W3/Y4sOIdGs4T8sOoREOxBodaWchPkXVJm0rozjgaOBZZKeiVvJy4Cj4mXOORcUU/apPSQNkDRN0sL4//4J64yQ9KCkeZKekHRcln2nFTF6EfgvoptK9gO+BPwKmGRmL1T8nTjnXJXVsFbGGcB0MxsGTI+/bmk1cLKZDQdGA/8tqV/ajtNGZZwNHBavNw0YBdwDnCFpTzP7eSXfhXPOVVsNb8keAxwUz08iyo0/KF3BzJ4pmX9F0mvA1sDK1nac1sd8NNGojJ5EXRiDzextSb8G/gGkJuZtu4RZ/Wb4UWHGNeuGvkWHkGjdTdOKDqGsVb+/o+gQEq1SmDdyhHrdJy81HMc80MyWxvPLgIGtrSxpFNADeC5tx2mJuSEujL9a0nNm9jaAmb0nqR762J1zdaaSxBQ/Iq/0MXkT41FlzcvvBrZN2PTM0i/MzKTytxNLGkQ0/PiU5oEUrUlLzOsk9TKz1cDeJQfpS31c/HTO1ZlKElPp0N4yyw8pt0zSq5IGmdnSOPG+Vma9PsBtwJlm9lCWuNJGZRwQJ2VaZPnuwClZDuCcc7VUwyeYTOFfefAUojuiNyKpB3Aj8Aczm5x1x2lFjNaWeX05sDzLAd5uDPOhpw/eEGZcK7t2LTqERLdOG0S/xjAr3R54aqDjcp8sOoBkoZ5jealhH/P5wPWSxgEvAccCSBoJfNXMxsevHQC8T9LYeLuxZvZYazv2IkYuk1CTsnMt1epMNbMVwMEJr88Bxsfz19JK+Ypy0obL9QJOI2r1/5boppKjgPnAuc1PN3HOuVA01UHhz7Q+5muIhoAMJeq8HglcSPRYqf8tt1FpEaPb30sdGeKcc7npDA9j3cXMjpUkYClwSDws5H7g8
XIblV7pvGvg8R3/7cs512HUQ8LJ1MccJ+PbzaKR6Wlj9kp9cPtM1whrToE+KfgjWzakr1SALY8aXnQIZS2ftKDoEBLtFeZDsunRO8xzLC8ht4SzSkvMcyT1NrNVZnZq84uSdgLeqW5ozjlXuYZAHxtXibThcuMljZJkZjZb0oeICnEsAPavSYTOOVeBjp+WKyhiJGkasC8wk6hQxwgy1Mpwzrla6gxdGe0uYnT3a0m3mRfv2M+G2fd91tStiw4h0ZYXtVoMq1Bn/WVC+koFWPHty4oOIdEvXhhUdAhlXZLDPuphuJwXMXLO1ZWOn5a9iJFzrs7UQ2JKS8wHNNfL8CJGzrmOoLEO2sxVL2LUM9C3rxfv7F50CIl2awy3fMmO68Ic/zpjzI1Fh5DIWq+bXpjdetR5EaOiA8hBuFnABSXUpOxcS1YHLeZWa2VIOk3SVvH8zpJmSVop6R+S9qhNiM45l1091MpIK2L0tbjbAqKRLBebWT+iccyXltuotIjRjNULcwrVOefSNWGZp1ClJebSro5tzOxGADO7B9iy3EZmNtHMRprZyE/0Gtb+KJ1zLqMaPsGkatL6mCdLugY4F7hR0reIHpPyCWBRlgOsr93TBCpyb0OYT6N+sWuYBemHBXwav9wtzAu5R/+oX9EhJPrJhSuKDqGqGgI+V7NKG5VxZvw4lOuAnYjuAJwA3AScWPXonHOuQvVw8S/LqIyngNPiIkbDiYoYPW1mb1U3NOecq1zIF/WyqrSI0SjgHuAMSXuamRcxcs4FpTO0mNtdxGhww/p2B1kNr3UNs19yl3VhnlTvKNybEvoH+qDY2897s+gQEh3aFOY5lpe6bzHjRYyccx1Mo3X8N5604XLr4idlgxcxcs51ALUaxyxpgKRpkhbG//dPWOf9kh6V9JikeZK+mmXfaYn5gLiynBcxcs51CFbBv3Y6A5huZsOA6fHXLS0F9jOzEUQPGjlD0nZpO656ESMjzIHM3QL9uLOsW7jlSz69++KiQ0i09Jk+RYeQaMfDwzzHbvxrmOOr81LDj/JjgIPi+UlEAyN+ULqCma0r+bIn6Y1hyLqSc6EmZedaqqQro7R8RDxV8jicgWa2NJ5fBsnlBCUNkfQEsBi4wMxeSdtx2nC5LsBY4PPAYKAReAa4NL4t2znnglJJF4WZTQQmllsu6W4g6fl4Z7bYj0nJj+c2s8XAh+MujJskTTazV1uLK+1z85XAS8AviYbOvQ3cB5wlaQ8z+22Zb2YC0R2CfHPLkXx6851SDuOcc/nIc1SGmR1SbpmkVyUNMrOlkgYBr6Xs6xVJc4H9gcmtrZvWlbG3mf3UzO43s28Bh5rZNOAzwNdbCWBDESNPys65Wqphdbkp/GsQxCnAzS1XkDRY0ubxfH/g48CCtB2ntZjXS9rJzJ6TtBewDqKLguWa7S2FWl59zPWjiw4h0X1H31p0CInuf3z7okMoa/SNhxcdQqKHP/e3okNINMjCvOkrLzW8+Hc+cL2kcUQ9C8cCSBoJfNXMxgO7ARfF+VLAr83sybQdpyXm7wEzJa2N1z0+PvDWQJgZxDnXqdXqlmwzWwEcnPD6HGB8PD8N+HCl+04bLjdD0nFEdwDOlvQhSd8B5pvZ9ys9mHPOVVvIBfCz8iJGzrm6YoHeo1CJqhcxCvUGk6bZs4oOIVHPUEuQBHyuzzpyk2suQXhPYRbK6q/67mNuDPlkzciLGDnn6krdd2UQFzGK62V4ESPnXPA6Q1fGAc31MryIkXOuI6j7FnMeRYycc66WOsMTTNpt3xGp9ToKYW8m1hsp3MJuPYsOIdFhO7xcdAhlyUtxVeT2F8O9WeiAHPZR94XyJXWV9BVJP5P0by2WnVXd0JxzrnI1vCW7atLaGpcBBwIrgN9I+q+SZUeV26i0lN4fliwtt5pzzuWuMyTmUWZ2gpn9N1H1/d6S/iapJ5QfoFxaxOjkwYPyjNc551plZpmnUKX1MfdonjGzBmBCfDfgDKB3lgPcM
ndI26Oroj3nv110CImusRVFh5Bo2PNbFR1CWR/c9fWiQ0h0+gsDig4h0YsW7qfYcTnsI+SWcFZpLeY5kjYqw2Zm5wBXAztWKyjnnGurGj7zr2rShsud1PI1SX8ws5OBK6oWlXPOtVGjdfx739KKGE1p+RLw75L6AZjZEdUKzDnn2iLkvuOs0vqYhwDziFrHzYWeRwIXZT3AJwYua3Nw1fTuW2GOF/7LdqGeVCtYuyrMJ3ivXdW16BASndvvnaJDSNSjV6jnWD46Qx/z3sAjRA8efCt+AOt7Znavmd1b7eBcOEJNys611Bn6mJuAiyXdEP//ato2zjlXpKZO0JUBgJktAY6R9BmiJ2U751yQQm4JZ1VR69fMbgNuq1IszjnXbnU/KiMPXbuH+UN6eE3/okNItijMuI6/6XNFh1DWtMMnFx1CohFDXis6hETTFm1XdAhljc1hH/XQleF1uZxzdaVWF/8kDZA0TdLC+P+yrSpJfSQtkfQ/WfadVl3uwyXz3SWdJWmKpF9I6tXKdhuKGP3ptTDLfjrn6lOTWeapnc4AppvZMGB6/HU5PwMyP2g0rcV8Tcn8+cDORGOYNwcuLbdRaRGjE7cJ92OTc67+1HC43BhgUjw/CTgyaSVJewMDgbuy7jitj7m0gtzBwD5mtl7SLODxLAdoagzzKdmfGhpm4fce/cPsk1911m+KDqGsQ+/6UdEhJHrne78oOoREh2+7uOgQqqrRGmt1qIFmGypCLSNKvhuR1IWoMXsScEjWHacl5r6SPkfUsu5pZusBzMwkdfwedudc3anklmxJE4AJJS9NNLOJJcvvBrZN2PTMFscslxO/DtxuZkuk7I3UtMQ8C2iuh/GQpIFm9qqkbfFn/jnnAlTJLdlxEp7YyvKyrVxJr0oaZGZLJQ0Ckobh7AfsL+nrRKWSe0haZWat9Uen3vk3NiGY5upyB7e2rXPOFaGGRYymAKcQXX87Bbg5IZYTm+cljQVGpiVlqLy6HMAnKqku974jw3yCydeveK/oEBLd8vgTRYdQ1rcHjCo6hERvHHF10SEk+s7WYVYv+Mj8RUWHUNYbOeyjhuOYzweulzQOeAk4FkDSSOCrZja+rTtuS3W5faigupyrD6EmZedaqtUt2Wa2goSeAzObA2ySlM3sGjYe6VaWV5dzztWVRmvKPIXKq8s55+pKZyiUD3h1Oedcx1EPtTI6bXW5497rXnQIiU7utVfRISTabago/mwAAAg+SURBVPtwb63v++9hPsH7gSu2LjqERDf0CvPnlZd6aDGn1cr4gKSrJJ0nqbekyyXNlXSDpB1rE6JzzmXXhGWeQpWlVsZsYBXwEDAfOAyYClxVbqPSIkZXzV6YU6jOOZfOzDJPoUpLzFua2f+a2flAHzO7yMwWm9mVQNkSd6VFjE7dZ1iuATvnXGvqflQG0CRpF6Av0EvSSDObI2lnINOjiR/4bUN7Y6yKRT3CHFyyw7owf17vmxhmoSCAVT88r+gQEimxxELx1hJmYbG8dIaLf98HbgGaiEra/TCu0dyXjQt/OOdcEELuosgqbRzzdOCDJS/dL+lW4Ih4jLNzzgWl7h/GWqZWxkHATZIy1cpwzrlaqvsWM14rwznXwdRDH7Nae3eJq+9/E/g08D0ze0zS82b2gVoF2CKeCaVFrEMSamweV2VCjQvCjS3UuDqyVhPzhpWkwcDFwKtE/cs7VDuwMnHMMbORRRw7TaixeVyVCTUuCDe2UOPqyLxWhnPOBabT1spwzrlQpd35F5qQ+7FCjc3jqkyocUG4sYUaV4eVqY/ZOedc7XS0FrNzztW9oBOzpJ9KOr3oOEIk6UVJW5V8fVB8VyaSxkpqim+fb14+t7lUa+m2kvaW9IKkPasQ47ck9WrjtoX/7kt/phnXHytpu2rGVEsd4RyrV0EnZrcxST0kbZFx9SVEz2psbX8fBiYDx5nZPyX1jceu5+VbQJsSc9EktaXK1VigQyfmDniO1aXgfkCSzpT0j
KT7iet0SPqypNmSHpf0V0m9JG0Zvwt3j9fpU/p1jvHsKOnp+CEB8yTdJWlzSTtJmirpEUn3SdpVUtc4BknqJ6lR0gHxfmZJalMNVEm7SboIWADsknGzW4Hhkj5YZvluwE3AF83s4fi1jwML4tZqRWPVJW0h6bb4dzRX0tlESWqmpJnxOl+Q9GS8/IKSbUdLejTednrCvr8s6Q5Jm2eIY0dJ8yVdE59Hf5J0iKQHJC2UNCqeHpT0T0l/b/4Zxa3AKZJmANNb7HefeP2d4hbgvfHv/k5JgyQdDYwE/iTpsSyxpnwfP5a0QNL9kq6TdLqkEZIekvSEpBsllS2924bjBX+OdSqVFJWu9kT0VO4niVpZfYBngdOB95Wscx7wjXj+auDIeH4CcFEVYtoRaABGxF9fD5xE9Ic7LH5tX2BGPD8VGA58lughA2cCPYEXKjzuFsCXgPvjaRxRfezm5S8CW5V8fRBwazw/Fvgf4GRgUvzaXGDHkm3fAD6dcNytgG8Dj8XfyzFAjwzxfh64vOTrvqUxEiXpRcDWRMM0ZxBVLNwaWAwMjdcbEP//0/h3fxpwM9Czwt/XHkQNj0eIHuogYAxRougDdIvXPwT4a8nPbUlJDAcRJZ+PxfvZAegO/B3YOl7nOOCqeP4eYGQO59w+8c9/M2BLYGH8s3gCODBe51zgv9t5nA51jnWmKbSixPsDN5rZatioiNLuks4D+gG9gTvj168gKk16E9EJ9uUqxfWCmT0Wzz9C9Mf/MeAGaUNt257x//cBBwBDgV/GMd1LlKQrsZToD3G8mc1PWJ40nKbla38GzpQ0NGHdu4Hxku40s8YNOzBbTnSX58WS9iNKaj8GPpywj1JPAhfFLeFbzey+kp8NRMnmHjN7HUDSn4h+To3ALDN7IT7+GyXbnEyUtI80s/Upxy/1gpk9GR9nHjDdzEzSk0S/u77ApPgTjBEl22bTWsSwG9FwsEPN7BVJuwO7A9Pi768r0e8qT/8G3Gxma4A1km4hSqL9zOzeeJ1JwA3tPE5HO8c6jeC6Msq4BjjNzPYAziFqSWBmDwA7SjoI6Gpmc6t0/LUl843AAGClmY0omXaLl88ieoMZBdxO9GZyEFHCrsTRwMvA3yT9RNL7WyxfwcZPkRkALC9dwcwaiApO/SBh/6fF//++5QJJH5J0IfAH4AEyvOGZ2TPAXkQJ+jxJP0nbJoPmRDq4wu1Kf19NJV83EbXWfwbMNLPdgcOJz6fYuy32tRRYAzRfuBIwr+T3voeZHVphfKHoUOdYZxJaYp4FHBn34W5J9EcD0ce5pXH/8YkttvkD0bv21bULk7eBFyQdAxD3KX8kXvYwUWu6KW7xPAZ8heh7y8zM7jKz44iS/FvAzZLu1r8egnsP8MX4+F2JuldmJuzqGqKP6y0f2dwEnADsKunceD97SXqI6JPIfGBPMxtvZv9Ii1fRaITVZnYtcCFRkn6H6HcH0c/lQElbxfF+geiTxEPAAc0tLkkDSnb7T6Kf3RTlO9qhL1FCgugjeWtWAp8Bfhk3ABYAW8ctPSR1lzQ8Xrf0+22PB4DDJW0mqTdRt9i7wJuS9o/X+SLRz6/NOto51pkElZjN7FHg/4DHgTv418f/HwP/IDphW37k+hPRu/p1NQqz2YnAOEmPE5VGHQNgZmuJPn4/FK93H9Ef65NtOYiZrTCzS8xsBPAjohY7RK2+nePj/5OoP/7ahO3XAb8BtklYtgY4AjhC0n8A7wFfMrOPmdmVZraqglD3AB6W9BhwNtG1gInAVEkzzWwpcAbRH/bjwCNmdnPctTGBqNX2ONHvvzTG+4n6V29TydCtdvoVUaL9JxnKEpjZq0TJ8XdELeejgQvieB8jeiOGKEFd2t6Lf2Y2G5hC1M1wB9G58xZwCnChpCeAEUT9zO3Wgc6xTqPD3/kXXw0fY2ZfLDoW5/IiqbeZrVI0DnwWMCFuuLhOILSLfxWR9FvgMKJ60c7Vk4mSPkTU/z3Jk3Ln0uFbz
M45V2+C6mN2zjnnidk554Ljidk55wLjidk55wLjidk55wLjidk55wLz/xlh3esl+SLIAAAAAElFTkSuQmCC\n","text/plain":["<Figure size 432x288 with 2 Axes>"]},"metadata":{"tags":[],"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"NWakIG-C6Ybh"},"source":["The filters have high values for the words `stock` and `market` which influenced the `Business` category classification."]},{"cell_type":"markdown","metadata":{"id":"3UsWNPkaHGUZ"},"source":["> This is a crude technique (maxpool doesn't strictly behave this way on a batch) loosely based off of more elaborate [interpretability](https://arxiv.org/abs/1312.6034) methods."]}]}